{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "DATASET = 'nyt'\n",
    "ARCH = '2p'\n",
    "\n",
    "tr = pickle.load(open('Dataset/dataset_mul-rel/%s/tr.pkl'%(DATASET), 'rb'))\n",
    "vl = pickle.load(open('Dataset/dataset_mul-rel/%s/vl.pkl'%(DATASET), 'rb'))\n",
    "ts = pickle.load(open('Dataset/dataset_mul-rel/%s/ts.pkl'%(DATASET), 'rb'))\n",
    "\n",
    "print(len(tr), len(vl), len(ts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pre_tr = []\n",
    "for i in range(6):\n",
    "    tmp = pickle.load(open('Dataset/dataset_mul-rel/%s/pre_tr-%d.pkl'%(DATASET, i), 'rb'))\n",
    "    pre_tr += tmp\n",
    "pre_vl = pickle.load(open('Dataset/dataset_mul-rel/%s/pre_vl.pkl'%(DATASET), 'rb'))\n",
    "pre_ts = pickle.load(open('Dataset/dataset_mul-rel/%s/pre_ts.pkl'%(DATASET), 'rb'))\n",
    "\n",
    "print(len(pre_tr), len(pre_vl), len(pre_ts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_REL = dict()\n",
    "MXL = 0\n",
    "\n",
    "for d in tr:\n",
    "    sent = d[0]\n",
    "    MXL = max(len(sent)+4, MXL)\n",
    "    \n",
    "    rels = d[1]\n",
    "    \n",
    "    for rel in rels:\n",
    "        rel = rel[2]\n",
    "        \n",
    "        if not rel in NUM_REL:\n",
    "            NUM_REL[rel] = 1\n",
    "\n",
    "print(NUM_REL)\n",
    "            \n",
    "NUM_REL = len(NUM_REL) + 1 # 0 for NA\n",
    "print('Num of relation: %d' % (NUM_REL))\n",
    "print('Max length: %d' % (MXL))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import spacy\n",
    "\n",
    "NLP = spacy.load('en_core_web_lg')\n",
    "\n",
    "POS = dict()\n",
    "for pos in list(NLP.tagger.labels):\n",
    "    POS[pos] = len(POS)+1\n",
    "\n",
    "NUM_POS = len(POS) + 1 # 0 for NA\n",
    "print(POS)\n",
    "print('Num of pos: %d' % (NUM_POS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as T\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm_notebook as tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DS(Dataset):\n",
    "    def __init__(self, dat):\n",
    "        super(DS, self).__init__()\n",
    "        \n",
    "        self.dat = dat\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.dat)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        return self.dat[idx]\n",
    "    \n",
    "ld_tr = DataLoader(DS(pre_tr), batch_size=32, shuffle=True)\n",
    "ld_vl = DataLoader(DS(pre_vl), batch_size=64)\n",
    "ld_ts = DataLoader(DS(pre_ts), batch_size=64)\n",
    "\n",
    "for idx, inp, pos, dep_fw, dep_bw, ans_ne, wgt_ne, ans_rel, wgt_rel in ld_ts:\n",
    "    print(idx.shape)\n",
    "    print(inp.shape, pos.shape, dep_fw.shape, dep_bw.shape)\n",
    "    print(ans_ne.shape, wgt_ne.shape)\n",
    "    print(ans_rel.shape, wgt_rel.shape)\n",
    "    \n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "class GCN(nn.Module):\n",
    "    def __init__(self, hid_size=256):\n",
    "        super(GCN, self).__init__()\n",
    "        \n",
    "        self.hid_size = hid_size\n",
    "        \n",
    "        self.W = nn.Parameter(T.FloatTensor(self.hid_size, self.hid_size//2).cuda())\n",
    "        self.b = nn.Parameter(T.FloatTensor(self.hid_size//2, ).cuda())\n",
    "        \n",
    "        self.init()\n",
    "    \n",
    "    def init(self):\n",
    "        stdv = 1/math.sqrt(self.hid_size//2)\n",
    "        \n",
    "        self.W.data.uniform_(-stdv, stdv)\n",
    "        self.b.data.uniform_(-stdv, stdv)\n",
    "    \n",
    "    def forward(self, inp, adj, is_relu=True):\n",
    "        out = T.matmul(inp, self.W)+self.b\n",
    "        out = T.matmul(adj, out)\n",
    "        \n",
    "        if is_relu==True:\n",
    "            out = nn.functional.relu(out)\n",
    "        \n",
    "        return out\n",
    "    \n",
    "    def __repr__(self):\n",
    "        return self.__class__.__name__+'(hid_size=%d)'%(self.hid_size)\n",
    "\n",
    "gcn = GCN().cuda()\n",
    "print(gcn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model_GraphRel(nn.Module):\n",
    "    def __init__(self, mxl, num_rel, \n",
    "                 hid_size=256, rnn_layer=2, gcn_layer=2, dp=0.5):\n",
    "        super(Model_GraphRel, self).__init__()\n",
    "        \n",
    "        self.mxl = mxl\n",
    "        self.num_rel = num_rel\n",
    "        self.hid_size = hid_size\n",
    "        self.rnn_layer = rnn_layer\n",
    "        self.gcn_layer = gcn_layer\n",
    "        self.dp = dp\n",
    "        \n",
    "        self.emb_pos = nn.Embedding(NUM_POS, 15)\n",
    "        \n",
    "        self.rnn = nn.GRU(300+15, self.hid_size, \n",
    "                          num_layers=self.rnn_layer, batch_first=True, dropout=dp, bidirectional=True)\n",
    "        self.gcn_fw = nn.ModuleList([GCN(self.hid_size*2) for _ in range(self.gcn_layer)])\n",
    "        self.gcn_bw = nn.ModuleList([GCN(self.hid_size*2) for _ in range(self.gcn_layer)])\n",
    "        \n",
    "        self.rnn_ne = nn.GRU(self.hid_size*2, self.hid_size, \n",
    "                             batch_first=True)\n",
    "        self.fc_ne = nn.Linear(self.hid_size, 5)\n",
    "        \n",
    "        self.trs0_rel = nn.Linear(self.hid_size*2, self.hid_size)\n",
    "        self.trs1_rel = nn.Linear(self.hid_size*2, self.hid_size)\n",
    "        self.fc_rel = nn.Linear(self.hid_size*2, self.num_rel)\n",
    "        \n",
    "        if ARCH=='2p':\n",
    "            self.gcn2p_fw = nn.ModuleList([GCN(self.hid_size*2) for _ in range(self.num_rel)])\n",
    "            self.gcn2p_bw = nn.ModuleList([GCN(self.hid_size*2) for _ in range(self.num_rel)])\n",
    "        \n",
    "        self.dp = nn.Dropout(dp)\n",
    "    \n",
    "    def output(self, feat):\n",
    "        out_ne, _ = self.rnn_ne(feat)\n",
    "        out_ne = self.dp(out_ne)\n",
    "        out_ne = self.fc_ne(out_ne)\n",
    "        \n",
    "        trs0 = nn.functional.relu(self.trs0_rel(feat))\n",
    "        trs0 = self.dp(trs0)\n",
    "        trs1 = nn.functional.relu(self.trs1_rel(feat))\n",
    "        trs1 = self.dp(trs1)\n",
    "        \n",
    "        trs0 = trs0.view((trs0.shape[0], trs0.shape[1], 1, trs0.shape[2]))\n",
    "        trs0 = trs0.expand((trs0.shape[0], trs0.shape[1], trs0.shape[1], trs0.shape[3]))\n",
    "        trs1 = trs1.view((trs1.shape[0], 1, trs1.shape[1], trs1.shape[2]))\n",
    "        trs1 = trs1.expand((trs1.shape[0], trs1.shape[2], trs1.shape[2], trs1.shape[3]))\n",
    "        trs = T.cat([trs0, trs1], dim=3)\n",
    "        \n",
    "        out_rel = self.fc_rel(trs)\n",
    "        \n",
    "        return out_ne, out_rel\n",
    "    \n",
    "    def forward(self, inp, pos, dep_fw, dep_bw):\n",
    "        pos = self.emb_pos(pos)\n",
    "        inp = T.cat([inp, pos], dim=2)\n",
    "        inp = self.dp(inp)\n",
    "        \n",
    "        out, _ = self.rnn(inp)\n",
    "        \n",
    "        for i in range(self.gcn_layer):\n",
    "            out_fw = self.gcn_fw[i](out, dep_fw)\n",
    "            out_bw = self.gcn_bw[i](out, dep_bw)\n",
    "            \n",
    "            out = T.cat([out_fw, out_bw], dim=2)\n",
    "            out = self.dp(out)\n",
    "        \n",
    "        feat_1p = out\n",
    "        out_ne, out_rel = self.output(feat_1p)\n",
    "        \n",
    "        if ARCH=='1p':\n",
    "            return out_ne, out_rel\n",
    "        \n",
    "        # 2p\n",
    "        out_ne1, out_rel1 = out_ne, out_rel\n",
    "        \n",
    "        dep_fw = nn.functional.softmax(out_rel, dim=3)\n",
    "        dep_bw = dep_fw.transpose(1, 2)\n",
    "        \n",
    "        outs = []\n",
    "        for i in range(self.num_rel):\n",
    "            out_fw = self.gcn2p_fw[i](feat_1p, dep_fw[:, :, :, i])\n",
    "            out_bw = self.gcn2p_bw[i](feat_1p, dep_bw[:, :, :, i])\n",
    "            \n",
    "            outs.append(self.dp(T.cat([out_fw, out_bw], dim=2)))\n",
    "        \n",
    "        feat_2p = feat_1p\n",
    "        for i in range(self.num_rel):\n",
    "            feat_2p = feat_2p+outs[i]\n",
    "        \n",
    "        out_ne2, out_rel2 = self.output(feat_2p)\n",
    "        \n",
    "        return out_ne1, out_rel1, out_ne2, out_rel2\n",
    "\n",
    "model = nn.DataParallel(Model_GraphRel(mxl=MXL, num_rel=NUM_REL)).cuda()\n",
    "\n",
    "if ARCH=='1p':\n",
    "    out_ne, out_rel = model(inp.cuda(), pos.cuda(), dep_fw.cuda(), dep_bw.cuda())\n",
    "    print(out_ne1.shape, out_rel1.shape)\n",
    "    \n",
    "else:\n",
    "    out_ne1, out_rel1, out_ne2, out_rel2 = model(inp.cuda(), pos.cuda(), dep_fw.cuda(), dep_bw.cuda())\n",
    "    print(out_ne1.shape, out_rel1.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def post_proc(dat, idx, out_ne, out_rel):    \n",
    "    out_ne = np.argmax(out_ne.detach().cpu().numpy(), axis=1)\n",
    "    out_rel = np.argmax(out_rel.detach().cpu().numpy(), axis=2)\n",
    "    \n",
    "    nes = dict()\n",
    "    el = -1\n",
    "    for i, v in enumerate(out_ne):\n",
    "        if v==4:\n",
    "            nes[i] = [i, i]\n",
    "            el = -1\n",
    "            \n",
    "        elif v==1:\n",
    "            el = i\n",
    "            \n",
    "        elif v==3:\n",
    "            if not el==-1:\n",
    "                for p in range(el, i+1):\n",
    "                    nes[p] = [el, i]\n",
    "                \n",
    "        elif v==2:\n",
    "            pass\n",
    "        \n",
    "        elif v==0:\n",
    "            el = -1\n",
    "    \n",
    "    rels = []\n",
    "    for i in range(MXL):\n",
    "        for j in range(MXL):\n",
    "            if not out_rel[i][j]==0 and i in nes and j in nes:\n",
    "                rels.append([nes[i][1], nes[j][1], out_rel[i][j]])\n",
    "    \n",
    "    cl = []\n",
    "    for rel in rels:\n",
    "        if not rel in cl:\n",
    "            cl.append(rel)\n",
    "    rels = cl\n",
    "    \n",
    "    ans = []\n",
    "    for tmp in dat[idx][1]:\n",
    "        ans.append([tmp[0][1], tmp[1][1], tmp[2]])\n",
    "    \n",
    "    cl = []\n",
    "    for rel in ans:\n",
    "        if not rel in cl:\n",
    "            cl.append(rel)\n",
    "    ans = cl\n",
    "    \n",
    "    return rels, ans\n",
    "\n",
    "class F1:\n",
    "    def __init__(self):\n",
    "        self.P = [0, 0]\n",
    "        self.R = [0, 0]\n",
    "    \n",
    "    def get(self):\n",
    "        try:\n",
    "            P = self.P[0]/self.P[1]\n",
    "        except:\n",
    "            P = 0\n",
    "        \n",
    "        try:\n",
    "            R = self.R[0]/self.R[1]\n",
    "        except:\n",
    "            R = 0\n",
    "            \n",
    "        try: \n",
    "            F = 2*P*R/(P+R)\n",
    "        except:\n",
    "            F = 0\n",
    "        \n",
    "        return P, R, F\n",
    "    \n",
    "    def add(self, ro, ra):\n",
    "        self.P[1] += len(ro)\n",
    "        self.R[1] += len(ra)\n",
    "        \n",
    "        for r in ro:\n",
    "            if r in ra:\n",
    "                self.P[0] += 1\n",
    "        \n",
    "        for r in ra:\n",
    "            if r in ro:\n",
    "                self.R[0] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 200\n",
    "LR = 0.0008\n",
    "DECAY = 0.98\n",
    "\n",
    "W_NE = 2\n",
    "W_REL = 2\n",
    "ALP = 3\n",
    "\n",
    "loss_func = nn.CrossEntropyLoss(reduction='none').cuda()\n",
    "optim = T.optim.Adam(model.parameters(), lr=LR)\n",
    "\n",
    "def ls(out_ne, wgt_ne, out_rel, wgt_rel):\n",
    "    ls_ne = loss_func(out_ne.view((-1, 5)), ans_ne.view((-1, )).cuda()).view(ans_ne.shape)\n",
    "    ls_ne = (ls_ne*wgt_ne.cuda()).sum() / (wgt_ne>0).sum().cuda()\n",
    "    \n",
    "    ls_rel = loss_func(out_rel.view((-1, NUM_REL)), ans_rel.view((-1, )).cuda()).view(ans_rel.shape)\n",
    "    ls_rel = (ls_rel*wgt_rel.cuda()).sum() / (wgt_rel>0).sum().cuda()\n",
    "    \n",
    "    return ls_ne, ls_rel\n",
    "\n",
    "for e in tqdm(range(EPOCHS)):\n",
    "    ls_ep_ne1, ls_ep_rel1 = 0, 0\n",
    "    ls_ep_ne2, ls_ep_rel2 = 0, 0\n",
    "    \n",
    "    model.train()\n",
    "    with tqdm(ld_tr) as TQ:\n",
    "        for i, (idx, inp, pos, dep_fw, dep_bw, ans_ne, wgt_ne, ans_rel, wgt_rel) in enumerate(TQ):\n",
    "            \n",
    "            wgt_ne.masked_fill_(wgt_ne==1, W_NE)\n",
    "            wgt_ne.masked_fill_(wgt_ne==0, 1)\n",
    "            wgt_ne.masked_fill_(wgt_ne==-1, 0)\n",
    "            \n",
    "            wgt_rel.masked_fill_(wgt_rel==1, W_REL)\n",
    "            wgt_rel.masked_fill_(wgt_rel==0, 1)\n",
    "            wgt_rel.masked_fill_(wgt_rel==-1, 0)\n",
    "            \n",
    "            out_ne1, out_rel1, out_ne2, out_rel2 = model(inp.cuda(), pos.cuda(), dep_fw.cuda(), dep_bw.cuda())\n",
    "            \n",
    "            ls_ne1, ls_rel1 = ls(out_ne1, wgt_ne, out_rel1, wgt_rel)\n",
    "            ls_ne2, ls_rel2 = ls(out_ne2, wgt_ne, out_rel2, wgt_rel)\n",
    "            \n",
    "            optim.zero_grad()\n",
    "            ((ls_ne1+ls_rel1) + ALP*(ls_ne2+ls_rel2)).backward()\n",
    "            optim.step()\n",
    "            \n",
    "            ls_ne1 = ls_ne1.detach().cpu().numpy()\n",
    "            ls_rel1 = ls_rel1.detach().cpu().numpy()\n",
    "            ls_ep_ne1 += ls_ne1\n",
    "            ls_ep_rel1 += ls_rel1\n",
    "            \n",
    "            ls_ne2 = ls_ne2.detach().cpu().numpy()\n",
    "            ls_rel2 = ls_rel2.detach().cpu().numpy()\n",
    "            ls_ep_ne2 += ls_ne2\n",
    "            ls_ep_rel2 += ls_rel2\n",
    "            \n",
    "            TQ.set_postfix(ls_ne1='%.3f'%(ls_ne1), ls_rel1='%.3f'%(ls_rel1), \n",
    "                           ls_ne2='%.3f'%(ls_ne2), ls_rel2='%.3f'%(ls_rel2))\n",
    "            \n",
    "            if i%100==0:\n",
    "                for pg in optim.param_groups:\n",
    "                    pg['lr'] *= DECAY\n",
    "            \n",
    "        ls_ep_ne1 /= len(TQ)\n",
    "        ls_ep_rel1 /= len(TQ)\n",
    "        \n",
    "        ls_ep_ne2 /= len(TQ)\n",
    "        ls_ep_rel2 /= len(TQ)\n",
    "        \n",
    "        print('Ep %d: ne1: %.4f, rel1: %.4f, ne2: %.4f, rel2: %.4f' % (e+1, ls_ep_ne1, ls_ep_rel1, \n",
    "                                                                       ls_ep_ne2, ls_ep_rel2))\n",
    "        T.save(model.state_dict(), 'Model/%s_%s_%d.pt' % (DATASET, ARCH, e+1))\n",
    "    \n",
    "    f1 = F1()\n",
    "    model.eval()\n",
    "    with tqdm(ld_vl) as TQ:\n",
    "        for idx, inp, pos, dep_fw, dep_bw, ans_ne, wgt_ne, ans_rel, wgt_rel in TQ:\n",
    "            _, _, out_ne, out_rel = model(inp.cuda(), pos.cuda(), dep_fw.cuda(), dep_bw.cuda())\n",
    "            \n",
    "            for i in range(idx.shape[0]):\n",
    "                rels, ans = post_proc(vl, idx[i], out_ne[i], out_rel[i])\n",
    "                f1.add(rels, ans)\n",
    "        \n",
    "        p, r, f = f1.get()\n",
    "        print('P: %.4f%%, R: %.4f%%, F: %.4f%%' % (100*p, 100*r, 100*f))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.load_state_dict('')\n",
    "\n",
    "f1 = F1()\n",
    "model.eval()\n",
    "with tqdm(ld_ts) as TQ:\n",
    "    for idx, inp, pos, dep_fw, dep_bw, ans_ne, wgt_ne, ans_rel, wgt_rel in TQ:\n",
    "        _, _, out_ne, out_rel = model(inp.cuda(), pos.cuda(), dep_fw.cuda(), dep_bw.cuda())\n",
    "            \n",
    "        for i in range(idx.shape[0]):\n",
    "            rels, ans = post_proc(ts, idx[i], out_ne[i], out_rel[i])\n",
    "            f1.add(rels, ans)\n",
    "        \n",
    "    p, r, f = f1.get()\n",
    "    print('Test: P: %.4f%%, R: %.4f%%, F: %.4f%%' % (100*p, 100*r, 100*f))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3.6",
   "language": "python",
   "name": "python36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
